import datetime
import logging
import os
from typing import Any

import pymssql
import pymssql._pymssql as PyMssql
from pymssql import Connection  # type: ignore

from unstract.connectors.constants import DatabaseTypeConstants
from unstract.connectors.databases.exceptions import (
    ColumnMissingException,
    InvalidSyntaxException,
)
from unstract.connectors.databases.exceptions_helper import ExceptionHelper
from unstract.connectors.databases.unstract_db import UnstractDB

logger = logging.getLogger(__name__)


class MSSQL(UnstractDB):
    """Unstract database connector for Microsoft SQL Server (via ``pymssql``)."""

    def __init__(self, settings: dict[str, Any]):
        """Store the connection settings used later by ``get_engine``.

        Args:
            settings (dict[str, Any]): connector settings; the keys ``user``,
                ``password``, ``server``, ``port`` and ``database`` are read.
                A missing key is stored as ``None``.
        """
        super().__init__("MSSQL")

        self.user = settings.get("user")
        self.password = settings.get("password")
        self.server = settings.get("server")
        self.port = settings.get("port")
        self.database = settings.get("database")

    @staticmethod
    def get_id() -> str:
        return "mssql|6c6af35c-9498-4bd6-9258-23b5337e068b"

    @staticmethod
    def get_name() -> str:
        return "MSSQL"

    @staticmethod
    def get_description() -> str:
        return "MSSQL Database"

    @staticmethod
    def get_icon() -> str:
        return "/icons/connector-icons/MSSQL.png"

    @staticmethod
    def get_json_schema() -> str:
        """Return the connector's settings JSON schema as raw text.

        Reads ``static/json_schema.json`` located next to this module. The
        file is opened in a ``with`` block so the handle is released even if
        reading fails, and is decoded explicitly as UTF-8 instead of relying
        on the platform's default encoding.
        """
        schema_path = os.path.join(
            os.path.dirname(__file__), "static", "json_schema.json"
        )
        with open(schema_path, encoding="utf-8") as f:
            return f.read()

    @staticmethod
    def can_write() -> bool:
        return True

    @staticmethod
    def can_read() -> bool:
        return True

    def sql_to_db_mapping(self, value: Any, column_name: str | None = None) -> str:
        """Map the Python type of ``value`` to an MSSQL column type.

        Matching uses the exact type (``type(value)``), not ``isinstance``, so
        subclasses such as ``bool`` are not matched by their base class and
        fall back to ``DatabaseTypeConstants.MSSQL_NVARCHAR_MAX``.

        Args:
            value (Any): sample value whose Python type selects the column type
            column_name (str | None): name of the column being mapped; not
                used by this connector

        Returns:
            str: database column type
        """
        data_type = type(value)

        # Structured values (dict/list) are stored as text.
        if data_type in (dict, list):
            return str(DatabaseTypeConstants.MSSQL_NVARCHAR_MAX)

        mapping = {
            str: DatabaseTypeConstants.MSSQL_NVARCHAR_MAX,
            int: DatabaseTypeConstants.MSSQL_INT,
            float: DatabaseTypeConstants.MSSQL_FLOAT,
            datetime.datetime: DatabaseTypeConstants.MSSQL_DATETIMEOFFSET,
        }
        return str(mapping.get(data_type, DatabaseTypeConstants.MSSQL_NVARCHAR_MAX))

    def get_engine(self) -> Connection:
        """Open and return a new ``pymssql`` connection from the stored settings."""
        return pymssql.connect(  # type: ignore
            server=self.server,
            port=self.port,
            user=self.user,
            password=self.password,
            database=self.database,
        )

    def get_create_table_base_query(self, table: str) -> str:
        """Build the opening of a guarded ``CREATE TABLE`` statement.

        The statement is wrapped in an ``IF NOT EXISTS`` check against
        ``INFORMATION_SCHEMA.TABLES``. ``table`` may be unqualified,
        ``schema.table`` or ``db.schema.table``, with parts optionally
        delimited as ``[...]`` or ``"..."``. The delimiters are removed for
        the existence check, because INFORMATION_SCHEMA stores bare names. An
        unqualified name is checked against the ``dbo`` schema.

        The returned fragment is deliberately unterminated. It ends with a
        trailing ``", "`` and no closing parenthesis, so the caller must
        append the remaining column definitions and close it.

        Args:
            table (str): table name as it should appear in ``CREATE TABLE``

        Returns:
            str: the create-table prefix including the constant columns
        """
        schema_name, table_name = self._resolve_schema_and_table(table)
        existence_check = (
            "IF NOT EXISTS ("
            "SELECT * FROM INFORMATION_SCHEMA.TABLES "
            f"WHERE TABLE_SCHEMA = {self._quote_literal(schema_name)} "
            f"AND TABLE_NAME = {self._quote_literal(table_name)}"
            ")"
        )

        sql_query = (
            f"{existence_check} "
            f"CREATE TABLE {table} "
            "(id NVARCHAR(MAX), "
            "created_by NVARCHAR(MAX), created_at DATETIMEOFFSET, "
            "metadata NVARCHAR(MAX), "
            "user_field_1 BIT DEFAULT 0, "
            "user_field_2 INT DEFAULT 0, "
            "user_field_3 NVARCHAR(MAX) DEFAULT NULL, "
            "status NVARCHAR(10) CHECK (status IN ('ERROR', 'SUCCESS')), "
            "error_message NVARCHAR(MAX), "
        )
        return sql_query

    def prepare_multi_column_migration(
        self, table_name: str, column_name: str
    ) -> list[str]:
        """Return the ``ALTER TABLE`` statements for the column migration.

        Each new column is added by its own ``ALTER TABLE ... ADD`` statement,
        and the statements are returned as a list. Presumably the caller
        executes them one at a time; verify against the caller.

        Args:
            table_name (str): table to migrate
            column_name (str): existing column; a ``<column_name>_v2``
                ``NVARCHAR(MAX)`` column is added next to it

        Returns:
            list[str]: one statement per added column
        """
        sql_statements = [
            f"ALTER TABLE {table_name} ADD {column_name}_v2 NVARCHAR(MAX)",
            f"ALTER TABLE {table_name} ADD metadata NVARCHAR(MAX)",
            f"ALTER TABLE {table_name} ADD user_field_1 BIT DEFAULT 0",
            f"ALTER TABLE {table_name} ADD user_field_2 INT DEFAULT 0",
            f"ALTER TABLE {table_name} ADD user_field_3 NVARCHAR(MAX) DEFAULT NULL",
            f"ALTER TABLE {table_name} ADD status NVARCHAR(10)",
            f"ALTER TABLE {table_name} ADD error_message NVARCHAR(MAX)",
        ]
        return sql_statements

    def execute_query(
        self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
    ) -> None:
        """Execute a create/insert query and commit it.

        Args:
            engine (Any): ``pymssql`` connection, as returned by ``get_engine``
            sql_query (str): create-table or insert-into-table query
            sql_values (Any): parameters to bind; when falsy the query is
                executed without parameters
            **kwargs: ``table_name`` (optional) is attached to
                ``ColumnMissingException`` for context

        Raises:
            InvalidSyntaxException: raised on ``pymssql`` ``OperationalError``
            ColumnMissingException: raised on ``pymssql`` ``ProgrammingError``
        """
        table_name = kwargs.get("table_name", None)
        try:
            with engine.cursor() as cursor:
                if sql_values:
                    cursor.execute(sql_query, tuple(sql_values))
                else:
                    cursor.execute(sql_query)
            engine.commit()
        except PyMssql.OperationalError as e:
            error_details = ExceptionHelper.extract_byte_exception(e=e)
            logger.error(
                "Invalid syntax in creating/inserting mssql data: %s", error_details
            )
            raise InvalidSyntaxException(
                detail=error_details, database=self.database
            ) from e
        except PyMssql.ProgrammingError as e:
            error_details = ExceptionHelper.extract_byte_exception(e=e)
            logger.error("Column missing in inserting data: %s", error_details)
            raise ColumnMissingException(
                detail=error_details,
                database=self.database,
                table_name=table_name,
            ) from e

    def get_information_schema(self, table_name: str) -> dict[str, str]:
        """Return the column names and data types of ``table_name``.

        ``table_name`` may be unqualified, ``schema.table`` or
        ``db.schema.table``, with parts optionally delimited as ``[...]`` or
        ``"..."``. An unqualified name is looked up in the ``dbo`` schema.
        Only the schema and table parts are used in the filter; any database
        part is ignored.

        The name is not lower-cased. ``get_create_table_base_query`` creates
        the table with the caller's casing, so lower-casing here would miss
        mixed-case tables under a case-sensitive collation. Under the default
        case-insensitive collation the two lookups behave the same.

        Args:
            table_name (str): db-connector table name

        Returns:
            dict[str, str]: db column name to db column type for the table
        """
        schema_name, table_only = self._resolve_schema_and_table(table_name)
        query = (
            "SELECT column_name, data_type FROM "
            "information_schema.columns WHERE "
            f"table_schema = {self._quote_literal(schema_name)} "
            f"AND table_name = {self._quote_literal(table_only)}"
        )

        results = self.execute(query=query)
        column_types: dict[str, str] = self.get_db_column_types(
            columns_with_types=results
        )
        return column_types

    @staticmethod
    def _split_qualified_name(name: str) -> list[str]:
        """Split a possibly-qualified SQL Server name into its bare parts.

        Parts are separated by ``.``. A part may be delimited with ``[...]``
        or ``"..."``; dots inside delimiters stay part of the name, the
        delimiters are removed, and a doubled closing delimiter (``]]`` or
        ``""``) is unescaped to a single character.
        """
        parts: list[str] = []
        current: list[str] = []
        closer: str | None = None  # closing delimiter while inside a quoted part
        i = 0
        while i < len(name):
            ch = name[i]
            if closer is not None:
                if ch != closer:
                    current.append(ch)
                elif name[i + 1 : i + 2] == closer:
                    # A doubled closing delimiter is an escaped literal char.
                    current.append(ch)
                    i += 1
                else:
                    closer = None
            elif ch == "[":
                closer = "]"
            elif ch == '"':
                closer = '"'
            elif ch == ".":
                parts.append("".join(current))
                current = []
            else:
                current.append(ch)
            i += 1
        parts.append("".join(current))
        return parts

    @classmethod
    def _resolve_schema_and_table(cls, table: str) -> tuple[str, str]:
        """Return ``(schema, table)`` bare names for INFORMATION_SCHEMA lookups.

        The last part is the table and the part before it is the schema. If
        the name is unqualified, or the schema part is empty (``db..table``),
        the schema defaults to ``dbo``.
        """
        parts = cls._split_qualified_name(table)
        table_name = parts[-1]
        schema_name = parts[-2] if len(parts) >= 2 and parts[-2] else "dbo"
        return schema_name, table_name

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Render ``value`` as an ``N'...'`` string literal.

        Embedded single quotes are doubled so the name cannot terminate the
        literal early. The ``N`` prefix makes the literal Unicode, matching
        the ``sysname`` columns it is compared against.
        """
        return "N'" + value.replace("'", "''") + "'"
